#!/usr/bin/env bash

# Copyright (c) 2023 ETH Zurich
#
# SPDX-License-Identifier: BSL-1.0
# Distributed under the Boost Software License, Version 1.0. (See accompanying
# file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

# Make a pipeline fail when any of its stages fails, not only the last one.
set -o pipefail

# NOTE(review): @PIKA_VERSION@ looks like a placeholder that is substituted when the script is
# configured/installed -- confirm against the build system.
version_message="pika-bind from pika version @PIKA_VERSION@"

# Help text printed by --help. The heredoc delimiter is quoted so the text is taken literally.
# read exits non-zero here because it reaches end-of-input before finding the NUL delimiter; this
# is expected and harmless.
read -r -d '' help_message << 'EOM'
Sets the PIKA_PROCESS_MASK environment variable based on process mask information found before the
pika runtime is started, and then runs the given command. pika-bind is useful in cases where pika
cannot detect the process mask anymore when starting the runtime. This can happen for example if
another runtime resets the process mask on the main thread before pika is initialized. See
https://pikacpp.org for more information.

Usage:
  pika-bind [-h|--help] [--version] [-v|--verbose] [-a|--allow-no-bind] [(-m|--method) (all|slurm|hwloc|taskset)] [-p|--print-bind] -- <command>

Options:
  -h --help            Show this screen.
  --version            Show version.
  -v --verbose         Be verbose.

  -a --allow-no-bind   Allow pika-bind to continue even if it cannot determine a process mask.
  -m --method <method> Set the method used to determine the process mask. Can be one of all
                       (default), slurm, hwloc, or taskset. When set to all, all available methods
                       are tried in the order slurm, hwloc, taskset. The first applicable method
                       will be used. The slurm method looks for the SLURM_CPU_BIND environment
                       variable. The hwloc method requires the hwloc-bind binary. The taskset method
                       requires the taskset binary.
  -p --print-bind      Set PIKA_PRINT_BIND environment variable to print pika bindings.
EOM


# Script options.
# Methods that can be selected with -m/--method; the first entry is the default.
valid_methods=(all slurm hwloc taskset)
method=${valid_methods[0]}

# Boolean flags follow the shell exit-status convention used throughout this script:
# 0 means "enabled", 1 means "disabled". Everything starts out disabled.
print_help=1
print_version=1
verbose=1
allow_no_bind=1
print_bind=1

# Parse command line options.
# getopt validates and normalizes the arguments; on an invalid option it reports the problem itself
# and exits non-zero, in which case we stop here.
if ! VALID_ARGS=$(getopt -o hvam:p --long help,version,verbose,allow-no-bind,method:,print-bind -- "$@"); then
    exit 1
fi

# getopt emits a shell-quoted argument list, so it has to be re-evaluated to become the new
# positional parameters. After the loop only the command to run (everything after "--") is left.
eval set -- "$VALID_ARGS"
while true; do
  case "$1" in
    -h | --help)
        print_help=0
        shift
        ;;
    --version)
        print_version=0
        shift
        ;;
    -v | --verbose)
        verbose=0
        shift
        ;;
    -a | --allow-no-bind)
        allow_no_bind=0
        shift
        ;;
    -m | --method)
        method=$2
        shift 2
        ;;
    -p | --print-bind)
        print_bind=0
        shift
        ;;
    --)
        shift
        break
        ;;
    *)
        # Should be unreachable because getopt always emits "--", but without this arm an
        # unexpected token would make the loop spin forever. The error helper is not defined yet
        # at this point of the script, so report directly.
        >&2 echo "pika-bind: ERROR: Unexpected argument \"$1\" while parsing options."
        exit 1
        ;;
  esac
done

# Helper functions
# Print the informational message $1 to stderr, but only when verbose output is enabled
# (verbose evaluates to 0). Does nothing otherwise.
function pika-bind-verbose {
    [[ ${verbose} -eq 0 ]] || return 0
    printf '%s\n' "pika-bind: INFO: $1" >&2
}

# Print the warning message $1 to stderr, regardless of the verbosity setting.
function pika-bind-warning {
    printf 'pika-bind: WARNING: %s\n' "$1" >&2
}

# Print the error message $1 to stderr and terminate the script with exit code 1.
# Errors go to stderr, like warnings and info messages, so that they never mix with the stdout of
# the wrapped command.
function pika-bind-error {
    >&2 echo "pika-bind: ERROR: $1"
    exit 1
}

# Decide whether the detection method named by $1 should be tried, based on the global "method".
# Returns 0 when method is "all" or equal to $1, and 1 otherwise. The decision is logged through
# pika-bind-verbose.
function method-enabled {
    local candidate=$1

    case "${method}" in
        all)
            pika-bind-verbose "Trying method \"${candidate}\" because method is set to \"all\""
            return 0
            ;;
        "${candidate}")
            pika-bind-verbose "Trying method \"${candidate}\" because method is set to \"${method}\""
            return 0
            ;;
        *)
            pika-bind-verbose "Skipping method \"${candidate}\" because method is set to \"${method}\""
            return 1
            ;;
    esac
}

# Process the --version and --help flags. Each prints its message and exits successfully without
# running any command; --version is checked first, so it wins when both flags are given.
if (( print_version == 0 )); then
    printf '%s\n' "$version_message"
    exit 0
fi

if (( print_help == 0 )); then
    printf '%s\n' "$help_message"
    exit 0
fi

# Return 0 if $1 is exactly one of the entries of the global valid_methods array, 1 otherwise.
# The comparison is a literal string comparison, so regex metacharacters or embedded spaces in a
# user-supplied method can never make an invalid value pass.
function pika-bind-method-valid {
    local candidate
    for candidate in "${valid_methods[@]}"; do
        if [[ "$1" == "${candidate}" ]]; then
            return 0
        fi
    done
    return 1
}

# Reject unknown methods before trying any of them.
if ! pika-bind-method-valid "${method}"; then
    pika-bind-error "Invalid method: ${method}. See pika-bind --help for available methods."
fi

# -p/--print-bind: put PIKA_PRINT_BIND (with an empty value) into the environment of the command
# that is run at the end of the script.
if (( print_bind == 0 )); then
    pika-bind-verbose "Setting PIKA_PRINT_BIND variable. pika will print bindings."
    PIKA_PRINT_BIND=
    export PIKA_PRINT_BIND
fi

# Main script
# A mask that is already present in the environment is kept as it is; every detection method
# further down only runs while PIKA_PROCESS_MASK is still empty.
if [[ -n "${PIKA_PROCESS_MASK}" ]]; then
    pika-bind-verbose "PIKA_PROCESS_MASK is already set to \"${PIKA_PROCESS_MASK}\", not setting it again"
fi

# Try to take the process mask from the SLURM_CPU_BIND environment variable.
# Expecting:
# quiet,mask_cpu:0x00003FFFF00003FFFF,0xFFFFC0000FFFFC0000
# i.e. options, a colon, and a comma-separated list of masks. The mask at zero-based position
# SLURM_LOCALID in that list is exported as PIKA_PROCESS_MASK.
#
# Does nothing if the slurm method is not enabled or PIKA_PROCESS_MASK is already set. The method
# is skipped (with a verbose message) if SLURM_CPU_BIND is empty, SLURM_LOCALID is not a
# non-negative integer, or the selected entry is not a hexadecimal mask.
function pika-bind-try-slurm {
    local slurm_mask

    if ! method-enabled "slurm" || [[ -n "${PIKA_PROCESS_MASK}" ]]; then
        return 0
    fi

    pika-bind-verbose "Trying to use slurm CPU mask through SLURM_CPU_BIND=\"${SLURM_CPU_BIND}\""

    if [[ -z "${SLURM_CPU_BIND}" ]]; then
        pika-bind-verbose "Skipping slurm because SLURM_CPU_BIND is empty or unset"
        return 0
    fi

    # The local id selects the entry of the mask list, so it must be a plain number.
    if [[ ! "${SLURM_LOCALID}" =~ ^[0-9]+$ ]]; then
        pika-bind-verbose "Skipping slurm because SLURM_LOCALID=\"${SLURM_LOCALID}\" is not a non-negative integer"
        return 0
    fi

    # Take the text between the first and second colon, split it on commas, and pick the entry
    # for this task. The id is passed with -v instead of being pasted into the awk program.
    slurm_mask=$(printf '%s\n' "${SLURM_CPU_BIND}" |
        awk -F: -v local_id="${SLURM_LOCALID}" '{ split($2, masks, ","); print masks[local_id + 1] }')

    # The whole entry has to be a hexadecimal mask; both lowercase and uppercase digits are
    # accepted because the expected format shown above uses uppercase digits.
    if [[ "${slurm_mask}" =~ ^0x[0-9a-fA-F]+$ ]]; then
        export PIKA_PROCESS_MASK=${slurm_mask}
        pika-bind-verbose "Setting PIKA_PROCESS_MASK=\"${PIKA_PROCESS_MASK}\" using mask from SLURM_CPU_BIND"
    else
        pika-bind-verbose "Skipping slurm because SLURM_CPU_BIND not in expected format. Expecting format \"option1,option2,...:0x<mask1>,0x<mask2>,...\"."
    fi
}

pika-bind-try-slurm

# Try to take the process mask from the output of "hwloc-bind --get --taskset".
# Expecting:
# 0x7f00000007f
#
# Does nothing if the hwloc method is not enabled or PIKA_PROCESS_MASK is already set. The method
# is skipped (with a verbose message) if hwloc-bind is not available, fails, or prints something
# that is not a single hexadecimal mask.
function pika-bind-try-hwloc {
    local hwloc_output hwloc_exit_code

    if ! method-enabled "hwloc" || [[ -n "${PIKA_PROCESS_MASK}" ]]; then
        return 0
    fi

    pika-bind-verbose "Trying to use hwloc-bind CPU mask"

    # Check for the binary up front so that a missing hwloc-bind is reported through the verbose
    # channel instead of as a raw "command not found" from the shell.
    if ! command -v hwloc-bind > /dev/null 2>&1; then
        pika-bind-verbose "Skipping hwloc-bind because the hwloc-bind binary was not found"
        return 0
    fi

    hwloc_output=$(hwloc-bind --get --taskset)
    hwloc_exit_code=$?
    if [[ ${hwloc_exit_code} -ne 0 ]]; then
        pika-bind-verbose "Skipping hwloc-bind because hwloc-bind failed with exit code $hwloc_exit_code and output \"$hwloc_output\""
        return 0
    fi

    # The whole output has to be one hexadecimal mask; anything else is not exported.
    if [[ "${hwloc_output}" =~ ^0x[0-9a-fA-F]+$ ]]; then
        export PIKA_PROCESS_MASK=${hwloc_output}
        pika-bind-verbose "Setting PIKA_PROCESS_MASK=\"${PIKA_PROCESS_MASK}\" using mask from hwloc-bind"
    else
        pika-bind-verbose "Skipping hwloc-bind because output not in expected format. Expecting format \"0x<mask>\"."
    fi
}

pika-bind-try-hwloc

# Try to take the process mask from the output of "taskset --pid <pid of this script>".
# Expecting:
# pid 19304's current affinity mask: 7f00000007f
# The mask after the colon is exported as PIKA_PROCESS_MASK with a "0x" prefix added.
#
# Does nothing if the taskset method is not enabled or PIKA_PROCESS_MASK is already set. The
# method is skipped (with a verbose message) if taskset is not available, fails, or its output
# does not end in ": <hexadecimal mask>".
function pika-bind-try-taskset {
    local taskset_output taskset_exit_code
    # A colon, optional whitespace, then nothing but hexadecimal digits up to the end of the output.
    local -r mask_pattern=':[[:space:]]*([0-9a-fA-F]+)[[:space:]]*$'

    if ! method-enabled "taskset" || [[ -n "${PIKA_PROCESS_MASK}" ]]; then
        return 0
    fi

    pika-bind-verbose "Trying to use taskset CPU mask"

    # Check for the binary up front so that a missing taskset is reported through the verbose
    # channel instead of as a raw "command not found" from the shell.
    if ! command -v taskset > /dev/null 2>&1; then
        pika-bind-verbose "Skipping taskset because the taskset binary was not found"
        return 0
    fi

    taskset_output=$(taskset --pid "$$")
    taskset_exit_code=$?
    if [[ ${taskset_exit_code} -ne 0 ]]; then
        pika-bind-verbose "Skipping taskset because taskset failed with exit code $taskset_exit_code and output \"$taskset_output\""
        return 0
    fi

    if [[ "${taskset_output}" =~ ${mask_pattern} ]]; then
        export PIKA_PROCESS_MASK=0x${BASH_REMATCH[1]}
        pika-bind-verbose "Setting PIKA_PROCESS_MASK=\"${PIKA_PROCESS_MASK}\" using mask from taskset"
    else
        pika-bind-verbose "Skipping taskset because output not in expected format. Expecting format \"pid 12345's current affinity mask: <mask>\"."
    fi
}

pika-bind-try-taskset

# No method produced a mask: continue with a warning if --allow-no-bind was given (allow_no_bind
# is 0), otherwise abort with an error.
if [[ -z "${PIKA_PROCESS_MASK}" ]]; then
    if [[ ${allow_no_bind} -eq 0 ]]; then
        pika-bind-warning "Failed to set PIKA_PROCESS_MASK using chosen method (\"${method}\"). Continuing without a mask because --allow-no-bind is set."
    else
        pika-bind-error "Failed to set PIKA_PROCESS_MASK using chosen method (\"${method}\"). Exiting because --allow-no-bind is unset. Use --allow-no-bind to ignore failure to set PIKA_PROCESS_MASK."
    fi
fi

# Log the full command. "$*" joins all words into the single message argument; "$@" inside the
# string would split the message into several arguments, and only the first one is printed.
pika-bind-verbose "Executing command: \"$*\""

# Run the remaining positional parameters as the command. It inherits the exported environment,
# and as the last statement its exit status becomes the exit status of this script.
"$@"
